function [loss,post_mean_sigma2,trace_var_beta] = ridge_loss(Y,X,a0,b0,gamma)
%RIDGE_LOSS Compute the ridge loss for a regression of Y on X.
%
%   [loss, post_mean_sigma2, trace_var_beta] = ridge_loss(Y, X, a0, b0, gamma)
%
%   Inputs:
%     Y     - response; the formulas below yield scalars only when Y is an
%             n-by-1 column vector (assumed, not checked).
%     X     - n-by-k regressor matrix.
%     a0,b0 - scalars entering the variance term as
%             (b0 + RSS/2) / (a0 + n/2 - 1).
%             NOTE(review): this looks like the posterior mean of an
%             inverse-gamma(a0 + n/2, b0 + RSS/2) noise variance - confirm.
%             The denominator is not checked for being positive.
%     gamma - ridge penalty weight; the penalty matrix is gamma*k*eye(k).
%
%   Outputs:
%     post_mean_sigma2 - variance term described above.
%     trace_var_beta   - post_mean_sigma2 * trace(lambda_n^{-1} * (X'*X/n)).
%     loss             - post_mean_sigma2 + trace_var_beta.

k      = size(X,2);
n      = size(X,1);

% Cross-products are computed once and reused below.
XtX    = X'*X;
XtY    = X'*Y;

% Penalised Gram matrix.
lambda_n ...
       = XtX + (gamma*k*eye(k));

% Ridge coefficient estimate: mu_n = inv(lambda_n)*(X'*Y).
% Obtained with a linear solve (backslash) instead of an explicit matrix
% inverse, which is faster and numerically more accurate when lambda_n is
% poorly conditioned.
mu_n   = lambda_n \ XtY;

% Y'*Y - Y'*X*inv(lambda_n)*X'*Y, written using mu_n.
quad_resid ...
       = (Y'*Y) - (XtY'*mu_n);

post_mean_sigma2 ...
       = (b0 + (.5*quad_resid))/(a0+(n/2)-1);

% trace(sigma2 * inv(lambda_n) * (X'*X/n)); the scalar sigma2 is pulled
% out of the trace and the inverse is again replaced by a solve.
trace_var_beta ...
       = post_mean_sigma2 * trace(lambda_n \ (XtX/n));

loss   = ...
         post_mean_sigma2 + trace_var_beta;
end

